import glob
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
import torch


@dataclass
class Checkpoint:
    """A checkpoint file on disk together with its descriptive metadata.

    Equality between two instances is decided by the descriptive fields only;
    ``embedding`` is excluded because comparing two distinct multi-element
    tensors with ``==`` yields a tensor whose truth value is ambiguous, which
    would make the generated ``__eq__`` raise ``RuntimeError``.
    """
    # Identifier of the checkpoint (the loader in this file uses "<subject>-<version>").
    id: str
    # Version label, e.g. "v1".
    version: str
    # Prompt tokens associated with the checkpoint, e.g. ["<bear-v4>"].
    trigger_tokens: List[str]
    # Subject categories the checkpoint covers, e.g. ["bear", "teddy bear"].
    subject_types: List[str]
    # Free-text description; this is the text that gets embedded for retrieval.
    description: str
    # Additional key/value information about the checkpoint.
    metadata: Dict
    path: str  # Path to the checkpoint file
    # Embedding of ``description``; None until embeddings have been attached.
    embedding: Optional[torch.Tensor] = field(default=None, compare=False)


def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute the cosine similarity cos_sim(a[i], b[j]) for all i and j.

    Args:
        a: A vector or a matrix of row vectors. Anything ``torch.as_tensor``
            accepts (tensor, numpy array, nested list) is allowed.
        b: Same as ``a``; its vectors must have the same length as those of ``a``.

    Returns:
        A 2-D tensor ``out`` with ``out[i, j] = cos_sim(a[i], b[j])``. A 1-D
        input is treated as a single row, so two 1-D inputs give a 1x1 result.
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)

    # normalize() rejects integer tensors and mm() rejects operands of
    # different dtypes, so bring both inputs to one common floating dtype.
    # Inputs that already share a floating dtype are left untouched.
    common_dtype = torch.promote_types(a.dtype, b.dtype)
    if not (common_dtype.is_floating_point or common_dtype.is_complex):
        common_dtype = torch.get_default_dtype()
    a = a.to(common_dtype)
    b = b.to(common_dtype)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


# Recognised weight-file extensions, in lookup priority order.
_CHECKPOINT_EXTENSIONS = ('.safetensors', '.ckpt', '.pt', '.pth', '.bin')
# Recognised weight-file base names, tried in this order for each extension.
_CHECKPOINT_STEMS = ('checkpoint', 'model')


def _find_checkpoint_file(version_dir: Path) -> Optional[str]:
    """Return the first recognised weights file inside *version_dir*, or None."""
    for ext in _CHECKPOINT_EXTENSIONS:
        for stem in _CHECKPOINT_STEMS:
            candidate = version_dir / f"{stem}{ext}"
            if candidate.exists():
                return str(candidate)
    return None


def fetch_all_checkpoints(data_folder: str = "./data/checkpoints") -> List[Checkpoint]:
    """
    Fetch all checkpoints from data folder structure.

    Folder structure:
    data/checkpoints/
    ├── bear/
    │   ├── v1/
    │   │   ├── checkpoint.safetensors (or .ckpt, .pt)
    │   │   └── metadata.json
    │   ├── v2/
    │   │   ├── checkpoint.safetensors
    │   │   └── metadata.json
    │   └── v4/
    │       ├── checkpoint.safetensors
    │       └── metadata.json
    ├── cat/
    │   └── v1/
    │       ├── checkpoint.safetensors
    │       └── metadata.json
    └── ...

    metadata.json format:
    {
        "trigger_tokens": ["<bear-v4>"],
        "subject_types": ["bear", "teddy bear", "plush toy"],
        "description": "A small plush teddy bear, tan with a cream snout and belly",
        "created_at": "2025-04-11",
        "additional_info": {...}
    }

    Args:
        data_folder: Root folder laid out as ``<subject>/<version>/``.

    Returns:
        One Checkpoint per version directory that has both a readable
        metadata.json and a weights file, ordered by metadata path so the
        result is the same on every run. Directories that fail to load are
        reported on stdout and skipped; this function does not raise for them.
    """
    checkpoints = []
    data_path = Path(data_folder)

    # glob() yields entries in filesystem order, which differs between
    # machines; sort so that positional indices into the result are stable.
    for metadata_file in sorted(data_path.glob("*/*/metadata.json")):
        try:
            # The two directory levels above the file are <subject>/<version>.
            version_dir = metadata_file.parent
            subject_name = version_dir.parent.name
            version = version_dir.name

            # JSON is UTF-8; do not depend on the platform default encoding.
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            checkpoint_file = _find_checkpoint_file(version_dir)
            if not checkpoint_file:
                print(f"Warning: No checkpoint file found in {version_dir}")
                continue

            checkpoint_id = f"{subject_name}-{version}"
            # Strip only a leading "v"; a blanket replace would also mangle
            # labels such as "v1-dev".
            version_num = version[1:] if version.startswith("v") else version

            checkpoints.append(
                Checkpoint(
                    id=checkpoint_id,
                    version=version,
                    trigger_tokens=metadata.get("trigger_tokens", [f"<{checkpoint_id}>"]),
                    subject_types=metadata.get("subject_types", [subject_name]),
                    description=metadata.get("description", f"Checkpoint for {subject_name} {version}"),
                    metadata={
                        "created_at": metadata.get("created_at"),
                        "subject": subject_name,
                        "version_num": version_num,
                        # Keys in additional_info override the ones above;
                        # a JSON null is treated like a missing entry.
                        **(metadata.get("additional_info") or {}),
                    },
                    path=checkpoint_file,
                )
            )

        except Exception as e:
            # Best effort: one broken directory must not hide the others.
            print(f"Error loading checkpoint from {metadata_file}: {e}")
            continue

    print(f"Loaded {len(checkpoints)} checkpoints from {data_folder}")
    return checkpoints


def fetch_checkpoint_embeddings(
    checkpoints: List[Checkpoint],
    embedding_type: str = "openai",
    cache_dir: str = "./cache/embeddings",
    force_recompute: bool = False
) -> torch.Tensor:
    """
    Compute or load cached embeddings for checkpoint descriptions.

    The cache is one file per embedding type. It is reused only when it holds
    exactly the ids of ``checkpoints`` and, for caches that recorded them, the
    same description for every id. Otherwise the embeddings are recomputed and
    the cache file is overwritten. As a side effect each checkpoint's
    ``embedding`` attribute is set.

    Args:
        checkpoints: List of checkpoint objects
        embedding_type: Type of embedding model to use
        cache_dir: Directory to cache embeddings
        force_recompute: Force recomputation even if cache exists

    Returns:
        The embeddings, one entry per checkpoint in the order of
        ``checkpoints``. On a cache hit this is a stacked tensor; otherwise it
        is whatever ``compute_embeddings`` returns (presumably a tensor of the
        same layout — TODO confirm).
    """
    # NOTE(review): get_embedding_cls and compute_embeddings are not defined or
    # imported in the visible part of this file — confirm where they come from.
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, f"checkpoint_embeddings_{embedding_type}.pt")

    current_ids = [c.id for c in checkpoints]
    descriptions = [c.description for c in checkpoints]

    # Try to load from cache
    if not force_recompute and os.path.exists(cache_file):
        try:
            # NOTE: torch.load unpickles the file; only point cache_dir at a
            # location whose contents are trusted.
            cached_data = torch.load(cache_file)
            cached_ids = cached_data['checkpoint_ids']
            cached_embeddings = cached_data['embeddings']
            # Absent in caches written before descriptions were recorded.
            cached_descriptions = cached_data.get('descriptions')

            cache_is_valid = set(current_ids) == set(cached_ids)
            if cache_is_valid and cached_descriptions is not None:
                # An embedding is only reusable if the text it was computed
                # from is unchanged; matching ids alone would let an edited
                # description keep its stale embedding.
                id_to_description = dict(zip(cached_ids, cached_descriptions))
                cache_is_valid = all(
                    id_to_description[cid] == description
                    for cid, description in zip(current_ids, descriptions)
                )

            if cache_is_valid:
                # Reorder to match current checkpoint order
                id_to_embedding = dict(zip(cached_ids, cached_embeddings))
                embeddings = torch.stack([id_to_embedding[cid] for cid in current_ids])

                # Assign to checkpoint objects
                for i, checkpoint in enumerate(checkpoints):
                    checkpoint.embedding = embeddings[i]

                print(f"Loaded embeddings from cache: {cache_file}")
                return embeddings
        except Exception as e:
            # An unreadable or malformed cache is not fatal; fall through.
            print(f"Error loading cache, recomputing: {e}")

    # Compute embeddings
    embedding_cls = get_embedding_cls(embedding_type)
    embeddings = compute_embeddings(descriptions, embedding_cls=embedding_cls)

    # Store embeddings in checkpoint objects
    for i, checkpoint in enumerate(checkpoints):
        checkpoint.embedding = embeddings[i]

    # Save to cache (best effort: a failed write only costs a recompute later)
    try:
        cache_data = {
            'checkpoint_ids': current_ids,
            'descriptions': descriptions,
            'embeddings': embeddings,
            'embedding_type': embedding_type
        }
        torch.save(cache_data, cache_file)
        print(f"Saved embeddings to cache: {cache_file}")
    except Exception as e:
        print(f"Error saving cache: {e}")

    return embeddings


def retrieve_top_k_checkpoints(
    prompt: str,
    top_k: int = 10,
    embedding_type: str = "openai",
    checkpoints_cache: Optional[List[Checkpoint]] = None,
    data_folder: str = "./data/checkpoints",
    pinecone: bool = False,
    debug: bool = False
) -> List[Checkpoint]:
    """
    Stage 1: Initial Retrieval - Get top-K checkpoints based on semantic similarity.

    This is the only stage we implement. Stages 2 and 3 (metadata filtering and
    clarification) will be handled by Gemini API.

    Args:
        prompt: Text to match against the checkpoint descriptions.
        top_k: Maximum number of checkpoints to return. In the local
            (non-Pinecone) path, values below zero are treated as zero and
            values above the number of checkpoints are capped.
        embedding_type: Type of embedding model to use.
        checkpoints_cache: Checkpoints to search; when None they are loaded
            from ``data_folder``.
        data_folder: Folder to load checkpoints from when no cache is given.
        pinecone: Query Pinecone for the ranking instead of computing the
            similarities locally. The returned indices are used as positions
            in the checkpoint list.
        debug: Print the ranked results.

    Returns:
        List of top-K checkpoints sorted by relevance. In the local path each
        returned checkpoint also gets a ``similarity_score`` attribute.

    Raises:
        AssertionError: If ``pinecone`` is set but PINECONE_KEY is missing or
            ``embedding_type`` is not "openai".
    """
    # NOTE(review): get_embedding_cls, compute_embeddings and query are not
    # defined or imported in the visible part of this file — confirm their origin.

    # Fetch checkpoints from data folder
    if checkpoints_cache is None:
        checkpoints = fetch_all_checkpoints(data_folder)
    else:
        checkpoints = checkpoints_cache

    if len(checkpoints) == 0:
        print("Warning: No checkpoints found!")
        return []

    if pinecone:
        # Check the preconditions before spending an embedding call on a
        # request that cannot succeed.
        # NOTE: assert statements are removed when Python runs with -O.
        assert os.getenv('PINECONE_KEY'), "PINECONE_KEY required for Pinecone"
        assert embedding_type == "openai", "Only OpenAI embeddings supported for Pinecone"

    # Compute prompt embedding
    embedding_cls = get_embedding_cls(embedding_type)
    prompt_embedding = compute_embeddings([prompt], embedding_cls=embedding_cls)

    if pinecone:
        # Query Pinecone for top-K indices
        top_k_indices = query(prompt_embedding[0].tolist(), top_k=top_k)
        ranked_checkpoints = [checkpoints[idx] for idx in top_k_indices]

    else:
        # Compute or load checkpoint embeddings
        checkpoint_embeddings = fetch_checkpoint_embeddings(
            checkpoints,
            embedding_type=embedding_type
        )

        # Compute similarities. detach()/cpu() keep the numpy conversion from
        # failing on tensors that track gradients or live on another device.
        similarities = cosine_similarity(prompt_embedding, checkpoint_embeddings)
        cos_sim = similarities.detach().cpu().numpy().flatten()

        # Get top-K checkpoints. The lower clamp matters: a negative count
        # used as a slice end would return "all but the last |k|" results.
        top_k = max(0, min(top_k, len(checkpoints)))
        top_k_indices = np.argsort(cos_sim)[::-1][:top_k]
        ranked_checkpoints = [checkpoints[idx] for idx in top_k_indices]

        # Add similarity scores to checkpoints for debugging
        for checkpoint, idx in zip(ranked_checkpoints, top_k_indices):
            checkpoint.similarity_score = cos_sim[idx]

    if debug:
        print("\n=== Top-K Retrieval Results ===")
        for i, checkpoint in enumerate(ranked_checkpoints):
            print(f"\n--- Rank {i+1} ---")
            print(f"ID: {checkpoint.id}")
            print(f"Similarity Score: {getattr(checkpoint, 'similarity_score', 'N/A')}")
            print(f"Description: {checkpoint.description}")
            print(f"Trigger Tokens: {checkpoint.trigger_tokens}")
            print(f"Subject Types: {checkpoint.subject_types}")
            print(f"Path: {checkpoint.path}")

    return ranked_checkpoints
